import copy

from atorch.modules.distributed_modules.modules_registry import (
    _REPLACE_SPECS,
    _SHARD_MAP,
    _SHARDABLE_OPERATORS,
    _SHARDED_OPERATORS,
)


class BaseTensorParallelPlanner:
    """Baseline tensor-parallel planner.

    Implements the simplest possible strategy: every shardable operator found in
    the fx graph is replaced with the first registered distributed implementation,
    and all replaced operators are placed on the externally created process group
    "tensor". Subclasses can override :meth:`generate_sharding_plan` for smarter
    strategies (e.g. topology-aware group assignment).
    """

    def __init__(self, strategy_name="base_strategy"):
        """Init the planner.

        Args:
            strategy_name (str): a name identifying this planning strategy.
        """
        self.name = strategy_name
        # Reverse lookup: operator class/function -> registered operator name.
        self.shardable_operators_names = {v: k for k, v in _SHARDABLE_OPERATORS.items()}

    def generate_sharding_plan(
        self, model, graph, sharding_specs, tensor_shapes, device_topo, optimizer=None, **kwargs
    ):
        """Generates a sharding plan for the model. This implements the simplest strategy:
        match every shardable operator and replace it with its first distributed implementation.

        All distributed operators will be distributed over the process_group 'tensor', which is created
        externally. All operators should respect the order of ranks of the process_group 'tensor'
        (no reshuffle on ranks), so ranks are set to None.
        group and ranks can be retrieved via atorch.distributed.parallel_group_and_ranks('tensor')

        Override this method for other strategies. For example, this method can also get a topology of the
        GPU cluster, and create individual process groups for each operator.
        The group/ranks assignment can be specified in replaced_spec.
        Instruction on process group creation can be specified in process_groups.

        Args:
            model (torch.nn.Module): The base model to be sharded.
            graph (torch.fx.graph): The fx graph corresponding to the model. Graph is generated by
                TensorParallelOptimization, this is to ensure TensorParallelOptimization will be dealing with
                the same graph during tune and transform. Because graph cannot appear in config, it has to be
                traced again.
            sharding_specs (dict): a dict of sharding specs, specifying how inputs/outputs of a node are sharded.
                keys being the name of each node in the graph,
                values are also a dict, with two keys: 'input_spec', 'output_spec'
                    input_spec is itself a dict: with keys being the input node name,
                        values being the input node's sharding spec
                    output_spec is simply the sharding spec for the node.
                ShardingSpec object is the building block for sharding specs of nodes.
                    if a node's output is a torch.Tensor, then the output_spec is the corresponding ShardingSpec
                    if a node's output is a dict of Tensors, then output_spec is a dict of corresponding ShardingSpecs
                    similarly for other iterable output
                    sharding spec being None means the target is not a Tensor (not to be sharded)

                Example:
                    >>> sharding_specs = {
                    >>>     'input_': {
                    >>>         'input_spec': None,
                    >>>         'output_spec': MeshShardingSpec(dims=(0,), group='Model', ranks=None)
                    >>>     }
                    >>>     'layer_1': {
                    >>>         'input_spec': {
                    >>>             'input_': MeshShardingSpec(dims=(0,), group='Model', ranks=None)
                    >>>         },
                    >>>         'output_spec': MeshShardingSpec(dims=(0,), group='Model', ranks=None)
                    >>>     }
                    >>>     'layer_2': {
                    >>>         'input_spec': {
                    >>>             'input_': MeshShardingSpec(dims=(0,), group='Model', ranks=None)
                    >>>         },
                    >>>         'output_spec': (
                    >>>             MeshShardingSpec(dims=(0,), group='Model', ranks=None),
                    >>>             MeshShardingSpec(dims=(0,), group='Model', ranks=None)
                    >>>         )
                    >>>     }
                    >>>     'layer_3': {
                    >>>         'input_spec': {
                    >>>             'layer_1': MeshShardingSpec(dims=(0,), group='Model', ranks=None),
                    >>>             'layer_2': (
                    >>>                 MeshShardingSpec(dims=(0,), group='Model', ranks=None),
                    >>>                 MeshShardingSpec(dims=(0,), group='Model', ranks=None)
                    >>>             )
                    >>>         },
                    >>>         'output_spec': MeshShardingSpec(dims=(0,), group='Model', ranks=None)
                    >>>     }
                    >>>     'output': {
                    >>>         'input_spec': {'add': MeshShardingSpec(dims=(0,), group='Model', ranks=None)},
                    >>>         'output_spec': MeshShardingSpec(dims=(0,), group='Model', ranks=None)
                    >>>     }
                    >>> }
            tensor_shapes: shapes of the output of each node, to be used for inferring memory requirement
            device_topo: A DeviceTopology object corresponding to physical devices on which to distribute the model.
            optimizer: optional optimizer associated with the model; unused by this base strategy.

        Returns:
            replacement_map (dict): a map that describes how to shard nodes. keys are node names, values are the names
                of distributed implementation of the nodes.

                Example:
                    >>> replacement_map = {
                    >>>     'layer_1': 'RowParallelLinear',
                    >>>     'layer_2': 'RowParallelLinear'
                    >>> }

            replaced_specs (dict): the sharding spec of each node after applying the replacement_map. This is to be
                used as a guidance on where to insert resharding operators

            process_group_assignment (dict): assignment of each node to a process group.
                Keys are node names, values are dicts describing the process group and the ranks.
                In most cases, group can be just a str (like Model), ranks can be None.
                For finest control, we can assign a ProcessGroup to the group attribute, a list of ranks to ranks.
                Specifying ranks gives us a control of the topology of the parallel model.

                Example:
                    >>> process_group_assignment = {
                    >>>     'layer_1': {'group': 'model', 'ranks': None},
                    >>>     'layer_2': {'group': 'model', 'ranks': None}
                    >>> }

            changed_local_nodes (set): A set of node names. Nodes in this set are operators that work seamlessly
                over all different sharding_specs. These nodes do not have to be replaced,
                but the output and input of these nodes are sharded in a different way than in sharding_specs

            process_groups (dict): FIXME subject to change. A dict that describes how to create process_groups.
        """
        # Deep-copy so the caller's sharding_specs are never mutated in place.
        replaced_specs = copy.deepcopy(sharding_specs)
        replacement_map = {}
        process_group_assignment = {}

        ranks = device_topo.get_device_ranks()

        for node in graph.nodes:
            # Only module calls and function calls can be shardable operators.
            if node.op not in ("call_module", "call_function"):
                continue

            if node.op == "call_module":
                # For modules, the registry is keyed on the module class, not the instance.
                submodule = model.get_submodule(node.target)
                node_target = type(submodule)
            else:
                submodule = None
                node_target = node.target
            target_name = self.shardable_operators_names.get(node_target, "")

            if target_name in _SHARDABLE_OPERATORS:
                # Simplest strategy: take the first registered distributed implementation,
                # and put the node on process_group "tensor".
                sharded_name = _SHARD_MAP[target_name][0]
                sharded_op = _SHARDED_OPERATORS[sharded_name]
                # A module instance may still refuse sharding (e.g. incompatible shapes
                # for the given ranks); call_function targets are assumed shardable.
                shardable = True
                if submodule is not None:
                    shardable = sharded_op.orig_module_shardable(submodule, ranks)
                if shardable:
                    replacement_map[node.name] = sharded_name
                    # ranks=None: respect the native rank order of group "tensor".
                    process_group_assignment[node.name] = {"group": "tensor", "ranks": None}

            if node.name in replacement_map:
                # Rewrite this node's input/output specs to match the distributed
                # implementation chosen above.
                (
                    replaced_specs[node.name]["input_spec"],
                    replaced_specs[node.name]["output_spec"],
                ) = _REPLACE_SPECS[replacement_map[node.name]](
                    replaced_specs[node.name]["input_spec"],
                    replaced_specs[node.name]["output_spec"],
                    group=process_group_assignment[node.name]["group"],
                    ranks=process_group_assignment[node.name]["ranks"],
                )

        # FIXME for operators like +/-/*, there might be a problem of shape broadcasting.
        # Until this is handled/tested, do not propagate through the local operators.
        changed_local_nodes = set()

        # FIXME replaced_specs/process_group_assignment may contain ProcessGroup objects,
        # which cannot be serialized.
        best_config = {
            "replacement_map": replacement_map,
            "replaced_specs": replaced_specs,
            "changed_local_nodes": changed_local_nodes,
            "process_group_assignment": process_group_assignment,
            "process_groups": dict(),
        }

        return best_config
